#Utility functions:

#' Generates a random tree topology using the 'ape' package.
#'
#' @param LN The number of leaves (tips) in the tree. Default is 8.
#' @param furcation Either "binary" (fully bifurcating random tree, which is
#'   also plotted as a side effect) or "multi" (random tree in which a random
#'   subset of internal edges is collapsed, producing multifurcations).
#'   Any other value is an error.
#' @return A tree object of class "phylo" without branch lengths.
RandTreeSim <- function(LN=8, furcation="binary"){
  furcation <- match.arg(furcation, choices = c("binary", "multi"))
  if (furcation == "binary") {
    tree <- ape::rtree(n = LN, br = NULL)
    plot(tree, edge.width = 2)
  } else {
    # Start from a random binary tree with branch lengths, zero out a random
    # subset (each with probability 1/2) of the internal edges and let
    # di2multi() collapse those zero-length edges into multichotomies.
    tree <- ape::rtree(n = LN)
    internal_edges <- which(tree$edge[, 2] > LN)
    collapse <- internal_edges[stats::runif(length(internal_edges)) < 0.5]
    tree$edge.length[collapse] <- 0
    tree <- ape::di2multi(tree)
    # Match the binary branch, which is generated without branch lengths.
    tree$edge.length <- NULL
  }
  tree
}

#' Computes row indices that downsample a reference dataset per class.
#'
#' Every class with more than `min_n` members is randomly subsampled, without
#' replacement, to `min_n` members. Smaller classes are kept whole.
#'
#' @param RefData Reference data. Not used in the computation; kept so that
#'   existing callers continue to work.
#' @param ClassLabels Vector of class labels, one per row of the reference data.
#'   NA labels are never selected.
#' @param min_n Maximum number of samples kept per class. If NULL, the size of
#'   the smallest class is used. Default is 500.
#' @param ... Unused.
#' @return Vector of selected indices into `ClassLabels`, grouped by class in
#'   the order of `table(ClassLabels)`. NULL when there are no labelled rows.
#' @usage DownSampleIdx(RefData, ClassLabels, min_n = 500, ...)
DownSampleIdx <- function(RefData, ClassLabels, min_n=500, ...){
  classes <- table(ClassLabels)
  if (is.null(min_n)) {
    min_n <- min(classes)
  }
  class_names <- names(classes)
  class_sizes <- as.vector(classes)
  # Iterate by position: indexing the table by name fails for a "" label.
  per_class <- lapply(seq_along(class_names), function(k) {
    idx <- which(ClassLabels == class_names[k])
    if (class_sizes[k] > min_n) {
      # Index through sample.int() so a length-1 `idx` is never read as 1:idx.
      idx[sample.int(length(idx), size = min_n, replace = FALSE)]
    } else {
      idx
    }
  })
  unlist(per_class)
}

#' Column-wise maximum, ignoring NA values.
#' @param data A data.frame (or list of columns).
#' @return The sapply() result: one maximum per column, named by column.
colMax <- function(data) {
  sapply(X = data, FUN = max, na.rm = TRUE)
}
#' Column-wise minimum, ignoring NA values.
#' @param data A data.frame (or list of columns).
#' @return The sapply() result: one minimum per column, named by column.
colMin <- function(data) {
  sapply(X = data, FUN = min, na.rm = TRUE)
}

#' Collects the accuracy of every node model of a hierfit model.
#'
#' For each entry of `HieRMod@model`, the row of `$results` whose `mtry`
#' equals `$finalModel$mtry` is taken. Its `Accuracy` is reported as a
#' percentage rounded to one decimal.
#'
#' @param HieRMod hierfit model generated by CreateHier function
#' @param ... Unused.
#' @return Named numeric vector of accuracies (percent), named by node label.
#'   Only models whose node maps to a label in `tree$node.label` are kept.
NodesAcc <- function(HieRMod, ...){
  tree <- HieRMod@tree[[1]]
  n_tips <- length(tree$tip.label)
  per_node <- lapply(names(HieRMod@model), function(node_id) {
    fit <- HieRMod@model[[as.character(node_id)]]
    tuned <- fit$results[which(fit$results$mtry == fit$finalModel$mtry), ]
    # Model names are node numbers; internal node k has label node.label[k - n_tips].
    cbind(node = node_id,
          tuned,
          NodeLabel = tree$node.label[as.numeric(node_id) - n_tips],
          classSize = length(fit$levels))
  })
  nodeStats <- do.call(rbind, per_node)

  nodeStats <- nodeStats[which(nodeStats$NodeLabel %in% tree$node.label), ]
  rownames(nodeStats) <- nodeStats$NodeLabel
  nodeAcc <- round(nodeStats$Accuracy * 100, digits = 1)
  names(nodeAcc) <- nodeStats$NodeLabel
  nodeAcc
}

#' Lists the distinct parent nodes of a tree, i.e. its internal nodes.
#' @param tree Tree in phylo format.
#' @return Unique values of the first (parent) column of `tree$edge`, in
#'   order of first appearance.
DigestTree <- function(tree) {
  parent_column <- tree$edge[, 1]
  unique(parent_column)
}

#' Plots the top predictors of the model trained at one internal node.
#'
#' Calls varImpPlot() (presumably randomForest's; it is not namespaced here)
#' with n.var = 10 on the node's `finalModel`.
#'
#' @param refMod reference hierfit model
#' @param Node Label of one of the internal nodes on the tree. It is cleaned
#'   with FixLab() before lookup.
#' @return The value returned by varImpPlot(), returned visibly.
GetImportantFeatures <- function(refMod, Node){
  phylo <- refMod@tree[[1]]
  node_idx <- match(FixLab(Node), c(phylo$tip.label, phylo$node.label))
  node_fit <- refMod@model[[as.character(node_idx)]]$finalModel
  plot_title <- paste("Important predictors of node",
                      phylo$node.label[node_idx - length(phylo$tip.label)],
                      sep = "\n")
  importance <- varImpPlot(node_fit, n.var = 10, main = plot_title)
  importance
}

#' Gathers the features used by any node model of a hiermod.
#' @param refMod hiermod
#' @return Character vector of the distinct `finalModel$xNames` across all
#'   internal nodes (NULL if there are none).
ExtractHierModfeatures <- function(refMod){
  nodes <- DigestTree(tree = refMod@tree[[1]])
  per_node <- lapply(nodes, function(node) {
    refMod@model[[as.character(node)]]$finalModel$xNames
  })
  unique(unlist(per_node))
}

#' Saves the hiermod S4 object as "<filePrefix>.RDS".
#'
#' Captured environments are stripped before saving, to keep the file small:
#' - For each model in `refMod@mlr`, the variables in the environment attached
#'   to its `terms` are removed. This mutates that environment in place.
#'   Shared environments (global, base, empty, attached packages, namespaces)
#'   are never emptied.
#' - In the saved copy, the `terms` of each model in `refMod@model` are
#'   re-pointed to the global environment. The global environment is
#'   serialized by reference, not by value. The caller's object is unchanged.
#'
#' @param refMod hiermod
#' @param filePrefix Prefix of the output file; ".RDS" is appended. Relative
#'   paths are resolved against the working directory.
#' @return The value of saveRDS().
SaveHieRMod <- function(refMod, filePrefix="Reference.HierMod"){
  # Emptying any of these would wipe the user's workspace or break R itself.
  is_shared_env <- function(env) {
    identical(env, emptyenv()) || isNamespace(env) ||
      any(vapply(seq_along(search()),
                 function(pos) identical(as.environment(pos), env),
                 logical(1)))
  }
  for (k in seq_along(refMod@mlr)) {
    env <- attr(refMod@mlr[[k]]$terms, ".Environment")
    if (is.environment(env) && !is_shared_env(env)) {
      rm(list = ls(envir = env), envir = env)
    }
  }
  # Assign into the slot itself: modifying an lapply() argument only changes a copy.
  for (k in seq_along(refMod@model)) {
    if (!is.null(refMod@model[[k]]$terms)) {
      attr(refMod@model[[k]]$terms, ".Environment") <- globalenv()
    }
  }
  saveRDS(refMod, file = paste0(filePrefix, ".RDS"))
}

#' Deprecated hiermod loading function.
#'
#' Reads the RDS file. It then re-attaches the global environment to the
#' `terms` of every model in the `mlr` and `model` slots that has one.
#'
#' @param fileName Path of the saved object.
#' @return The loaded hiermod object.
LoadHieRMod <- function(fileName){
  refMod <- readRDS(file = fileName)
  # Assign into the slots themselves: modifying an lapply() argument only
  # changes a local copy, so the original code had no effect.
  for (k in seq_along(refMod@mlr)) {
    if (!is.null(refMod@mlr[[k]]$terms)) {
      attr(refMod@mlr[[k]]$terms, ".Environment") <- globalenv()
    }
  }
  for (k in seq_along(refMod@model)) {
    if (!is.null(refMod@model[[k]]$terms)) {
      attr(refMod@model[[k]]$terms, ".Environment") <- globalenv()
    }
  }
  refMod
}

#' For every child of a node, retrieves the leaf labels found below it.
#'
#' @param tree A tree (phylo format) storing the relationship between the
#'   class labels.
#' @param node Number of a non-terminal node in the tree.
#' @return A list indexed by child node number. Element `[[k]]` holds the tip
#'   labels below child `k`, or the single label when `k` is itself a tip.
#'   Positions that are not children of `node` are NULL. The list is empty
#'   when `node` has no children.
GetChildNodeLeafs <- function(tree, node){
  n_tips <- length(tree$tip.label)
  kids <- tree$edge[which(tree$edge[, 1] == node), 2]
  leafs <- list()
  for (kid in kids) {
    # Node numbers 1..n_tips are tips; larger numbers are internal nodes.
    leafs[[kid]] <- if (kid <= n_tips) {
      tree$tip.label[kid]
    } else {
      ape::extract.clade(tree, kid)$tip.label
    }
  }
  leafs
}

#' Extracts the ancestral path of the given node up to the taxa root.
#'
#' @param tree Tree in phylo format.
#' @param class Label of an internal or leaf node.
#' @param labels If FALSE (default), each element is the parent's node number
#'   pasted in front of the current label (e.g. "8B", "7A"). If TRUE, each
#'   element is just the current label.
#' @return Character vector starting at `class` and moving upward, one element
#'   per edge. The root never gets its own element. NULL when `class` has no
#'   parent or is not a label of the tree.
GetAncestPath <- function(tree, class, labels=FALSE){
  all_labels <- c(tree$tip.label, tree$node.label)
  # Parent node number of node `idx`, or NA when `idx` has no incoming edge.
  parent_of <- function(idx) tree$edge[which(tree$edge[, 2] == idx), ][1]

  path <- c()
  current <- class
  up <- parent_of(match(class, all_labels))
  while (!is.na(up)) {
    entry <- if (labels) current else paste(up, current, sep = "")
    path <- c(path, entry)
    current <- all_labels[up]
    up <- parent_of(up)
  }
  path
}

#' Normalizes class labels.
#'
#' The rules are applied in order:
#' 1. Spaces become "_".
#' 2. "+", "-" and "/" become ".".
#' 3. Backticks and commas are removed.
#'
#' @param xstring Character vector of class labels.
#' @return Character vector of cleaned labels (same length).
FixLab <- function(xstring){
  rules <- list(c(" ", "_"),
                c("\\+|-|/", "."),
                c("`|,", ""))
  for (rule in rules) {
    xstring <- gsub(pattern = rule[1], replacement = rule[2], x = xstring)
  }
  xstring
}

#' Randomly scrambles the values of a data frame.
#'
#' First, the values within each row are permuted. Then the values within each
#' column are permuted. Row and column names are kept in place; the
#' data is coerced to a single type via as.matrix().
#'
#' @param df Input data frame.
#' @return A data.frame with the same dimensions and dimnames as `df`.
Shuffler <- function(df){
  # Index-based permutation: sample(x) on a single number n >= 1 would return
  # a permutation of 1:n instead of x.
  permute <- function(x) x[sample.int(length(x))]
  # Explicit loops keep the matrix shape even for one row or one column,
  # where apply() would collapse the result to a vector.
  dfr <- as.matrix(df)
  for (i in seq_len(nrow(dfr))) {
    dfr[i, ] <- permute(dfr[i, ])
  }
  for (j in seq_len(ncol(dfr))) {
    dfr[, j] <- permute(dfr[, j])
  }
  colnames(dfr) <- colnames(df)
  rownames(dfr) <- rownames(df)
  as.data.frame(dfr)
}
#' Retrieves the orthologous-genes table between reference-species genes and
#' query-species genes from Ensembl BioMart, using biomaRt.
#'
#' Only rows with orthology confidence 1 and type "ortholog_one2one" are
#' kept. Both gene-name columns are cleaned with FixLab().
#'
#' @param Genes_r Reference gene names (matched on external_gene_name).
#' @param species_r Reference species; the dataset "<species_r>_gene_ensembl"
#'   is queried.
#' @param species_q Query species; prefix of the "<species_q>_homolog_*"
#'   attributes.
#' @return data.frame of the filtered getBM() result.
GetOrtologs <- function(Genes_r, species_r, species_q){
  homolog_attr <- function(field) paste(species_q, field, sep = "_")
  name_attr <- homolog_attr("homolog_associated_gene_name")
  confidence_attr <- homolog_attr("homolog_orthology_confidence")
  type_attr <- homolog_attr("homolog_orthology_type")
  gene_attr <- homolog_attr("homolog_ensembl_gene")

  mart <- biomaRt::useMart("ensembl",
                           dataset = paste(species_r, "gene_ensembl", sep = "_"),
                           host = "www.ensembl.org",
                           ensemblRedirect = FALSE)
  Ort <- biomaRt::getBM(attributes = c("external_gene_name", "ensembl_gene_id",
                                       name_attr, confidence_attr,
                                       type_attr, gene_attr),
                        filters = 'external_gene_name',
                        values = Genes_r,
                        uniqueRows = TRUE,
                        mart = mart)

  high_confidence <- which(Ort[, confidence_attr] == 1)
  Ort <- Ort[high_confidence, ]
  one_to_one <- which(Ort[, type_attr] == "ortholog_one2one")
  Ort <- Ort[one_to_one, ]

  Ort[, 'external_gene_name'] <- FixLab(Ort[, 'external_gene_name'])
  Ort[, name_attr] <- FixLab(Ort[, name_attr])
  Ort
}

#' Extracts the sibling nodes of a given node (other nodes with the same parent).
#'
#' @param tree Tree in phylo format.
#' @param class Node label (tip or internal).
#' @return Character vector of sibling labels, tips first and then internal
#'   nodes, or NULL when there are none. Also NULL when `class` is the root or
#'   is not a label of the tree. A node labelled "TaxaRoot" is never reported.
GetSiblings <- function(tree, class){
  labs_l <- c(tree$tip.label, tree$node.label)
  # Parent node number of node `idx`, or NA when `idx` has no incoming edge.
  parent_of <- function(idx) tree$edge[which(tree$edge[, 2] == idx), ][1]

  parent <- parent_of(match(class, labs_l))
  # The root (or an unknown label) has no parent and hence no siblings.
  if (is.na(parent)) {
    return(NULL)
  }
  siblings <- c()
  for (cl in labs_l[!labs_l %in% c(class, "TaxaRoot")]) {
    # isTRUE(): the root's parent is NA, which must not abort the loop when
    # the root is not labelled "TaxaRoot".
    if (isTRUE(parent_of(match(cl, labs_l)) == parent)) {
      siblings <- c(siblings, cl)
    }
  }
  siblings
}

#' Classifies a prediction by where it falls in the tree relative to the true class.
#'
#' @param t True (prior) class label.
#' @param p Predicted class label.
#' @param tree Tree in phylo format whose tip and node labels are the classes.
#' @return One of the following strings:
#'   - "NotDefined": either label, after FixLab(), is not in the tree.
#'   - "Correct_node"
#'   - "Correct_parent_node"
#'   - "Correct_ancestral_node": a non-root ancestor above the parent.
#'   - "Incorrect_node_sibling"
#'   - "Correct_children_node"
#'   - "Correct_grandchildren_node": any descendant deeper than the children.
#'   - "Incorrect_clade"
EvalPred <- function(t, p, tree){
  t <- FixLab(t)
  p <- FixLab(p)
  labs_l <- c(tree$tip.label, tree$node.label)
  Node.t <- match(t, labs_l)
  Node.p <- match(p, labs_l)
  if (is.na(Node.p) || is.na(Node.t)) {
    return("NotDefined")
  }
  # Checked first so that a correct root prediction is also recognized.
  if (t == p) {
    return("Correct_node")
  }
  parent.t <- tree$edge[which(tree$edge[, 2] == Node.t), ][1]
  parent.p <- tree$edge[which(tree$edge[, 2] == Node.p), ][1]
  children <- tree$edge[which(tree$edge[, 1] == Node.t), 2]

  # Labels of t and its ancestors, with the root excluded. Exact matching is
  # used: grep() treated p as a regex over "<index><label>" strings, so short
  # labels or labels containing "." (FixLab output) matched unrelated nodes.
  lineage <- GetAncestPath(tree = tree, class = t, labels = TRUE)
  if (p %in% lineage) {
    if (isTRUE(labs_l[parent.t] == p)) {
      return("Correct_parent_node")
    }
    return("Correct_ancestral_node")
  }
  # isTRUE(): the root's parent is NA, which must not reach if().
  if (isTRUE(parent.t == parent.p)) {
    return("Incorrect_node_sibling")
  }
  if (Node.p %in% children) {
    return("Correct_children_node")
  }
  if (t %in% tree$node.label) {
    clade <- ape::extract.clade(tree, t)
    # All labels below t, minus t itself and its direct children.
    grandChilds <- c(clade$node.label, clade$tip.label)
    grandChilds <- grandChilds[!grandChilds %in% c(t, labs_l[children])]
    if (p %in% grandChilds) {
      return("Correct_grandchildren_node")
    }
  }
  "Incorrect_clade"
}

#' Determines the size of the intersection between the ancestral paths of the
#' true class and the predicted class.
#'
#' @param t True class.
#' @param p Predicted class.
#' @param tree Tree in phylo format.
#' @return Integer count of entries shared by GetAncestPath() of the two
#'   FixLab()-cleaned labels.
IntSectSize <- function(t, p, tree){
  paths <- lapply(list(t, p), function(lab) {
    GetAncestPath(tree = tree, class = FixLab(lab))
  })
  length(intersect(paths[[1]], paths[[2]]))
}

#' Computes hierarchical Precision, Recall and F-measure, plus per-category
#' evaluation rates.
#'
#' @param tpT PriorPostTable: a data.frame (or matrix) with two columns; the
#'   first is the prior (true) label and the second the predicted label.
#' @param tree Tree topology in phylo format.
#' @param BetaSq Beta coefficient (squared) of the F-measure.
#' @param ND_term The label used for undetermined class types.
#' @return Named numeric vector with the following entries:
#'   - Precision, Recall and Fmeasure.
#'   - One entry per EvalPred() category, giving the fraction of all rows,
#'     in name order. Undetermined predictions are not evaluated but stay in
#'     the denominator.
#'   - UndetectedRate: the fraction of rows predicted as `ND_term`. NA
#'     predictions are not counted here.
hPRF <- function(tpT, tree, BetaSq=1, ND_term="Undetermined"){
  # To Do: Consider Undetermined class!
  if (is.matrix(tpT)) {
    tpT <- as.data.frame(tpT, stringsAsFactors = FALSE)
  }
  # Work on plain character vectors. apply() on a data.frame coerced all
  # columns to one matrix type, padding numeric columns via format().
  prior <- as.character(tpT[[1]])
  post <- as.character(tpT[[2]])
  n_total <- length(prior)
  path_len <- function(lab) length(GetAncestPath(tree = tree, class = FixLab(lab)))

  Int <- vapply(seq_len(n_total),
                function(i) IntSectSize(t = prior[i], p = post[i], tree = tree),
                integer(1))
  PiL <- vapply(post, path_len, integer(1), USE.NAMES = FALSE)
  TiL <- vapply(prior, path_len, integer(1), USE.NAMES = FALSE)

  hP <- sum(Int) / sum(PiL)
  hR <- sum(Int) / sum(TiL)
  hF <- (BetaSq + 1) * hP * hR / (BetaSq * hP + hR)

  # ND rate; %in% keeps NA predictions out of the undetermined count.
  is_nd <- post %in% ND_term
  ND.rate <- sum(is_nd) / n_total

  # Correctness of the determined predictions.
  metrics <- c("Correct_node", "Correct_parent_node", "Correct_ancestral_node",
               "Incorrect_node_sibling", "Correct_children_node",
               "Correct_grandchildren_node", "Incorrect_clade", "NotDefined")
  Eval <- vapply(which(!is_nd),
                 function(i) EvalPred(t = prior[i], p = post[i], tree = tree),
                 character(1))
  # A factor with every metric as a level yields explicit zero counts.
  counts <- table(factor(Eval, levels = metrics))
  evals <- stats::setNames(as.numeric(counts) / n_total, names(counts))
  evals <- evals[order(names(evals))]
  c(Precision = hP, Recall = hR, Fmeasure = hF, evals, UndetectedRate = ND.rate)
}
